# Build a quanteda document-feature matrix (dfm) from a data frame of texts,
# joining a list of multi-word phrases into single tokens first.
#
# Arguments:
#   data        data frame with one document per row
#   text_col    name of the column holding the document text
#   doc_col     name of the column holding the document ids
#   to_compound character vector of additional phrases (words separated by
#               spaces) to compound, on top of the built-in list below
#   stem        if TRUE (default), stem tokens after compounding
#
# Returns the dfm built from the processed tokens (features lower-cased).
# Stop words are not removed by this function.
make.dfm <- function(data, text_col, doc_col, to_compound,
                     stem = TRUE){
  require(quanteda)
  
  message("Making corpus")
  corp <- quanteda::corpus(data, text_field = text_col, docid_field = doc_col)
  
  message("Setting compounds...")
  
  # phrase() splits each entry on whitespace, so an entry has to be written
  # with spaces (not underscores) to match a sequence of tokens
  compounds <- quanteda::phrase(
    c("white house", "fox news", "united states",
      "united nations",
      "social media", "climate change", "right wing",
      to_compound)
  )
  
  message("Tokenizing...")
  tok <- quanteda::tokens(corp,
                          remove_punct = TRUE,
                          remove_twitter = TRUE, 
                          remove_url = TRUE) %>%
    # drop tokens that contain separator characters
    quanteda::tokens_remove("\\p{Z}", valuetype = "regex") %>%
    # drop the word "via" only; the regex is anchored so that tokens merely
    # containing "via" (e.g. "trivial") are kept
    quanteda::tokens_remove("^via$", valuetype = "regex") %>%
    
    # stop-word removal is disabled in this function
    #quanteda::tokens_remove(quanteda::stopwords(source = "snowball")) %>%
    
    # concatenate common bigrams
    quanteda::tokens_compound(pattern = compounds)
  
  if(stem == TRUE){
    # stem
    tok <- tok %>% quanteda::tokens_wordstem()
  }
  
  message("Casting to dfm...")
  
  # send tokens to dfm
  bags_dfm <- quanteda::dfm(tok, verbose = TRUE, tolower = TRUE)
  
  return(bags_dfm)
}

# Build a tf-idf weighted dfm aggregated by user, keeping only the
# highest-weighted share of tokens.
#
# Arguments:
#   data        data frame with one document per row; must contain an `id`
#               column, which is used to group documents
#   text_col    name of the column holding the document text
#   doc_col     name of the column holding the document ids
#   to_compound character vector of additional phrases (words separated by
#               spaces) to compound, on top of the built-in list below
#   ngrams      NULL (default) to keep the tokens as they are, otherwise the
#               n-gram sizes passed to quanteda::tokens_ngrams()
#   stem        if TRUE (default), stem tokens after compounding
#   keep_pct    share of tokens to keep, ranked by summed tf-idf weight
#
# Returns the tf-idf dfm (one row per `id`) restricted to the kept tokens.
make.tfidf <- function(data, text_col, doc_col, to_compound,
                       ngrams = NULL,
                       stem = TRUE, keep_pct = .25){
  require(quanteda)
  
  message("Making corpus")
  corp <- quanteda::corpus(data, text_field = text_col, docid_field = doc_col)
  
  message("Setting compounds...")
  
  # phrase() splits each entry on whitespace, so an entry has to be written
  # with spaces (not underscores) to match a sequence of tokens
  compounds <- quanteda::phrase(
    c("white house", "fox news", "united states",
      "united nations",
      "social media", "climate change", "right wing",
      to_compound)
  )
  
  message("Tokenizing...")
  tok <- quanteda::tokens(corp,
                          remove_punct = TRUE,
                          remove_twitter = TRUE, 
                          remove_url = TRUE, include_docvars = TRUE) %>%
    # drop tokens that contain separator characters
    quanteda::tokens_remove("\\p{Z}", valuetype = "regex") %>%
    # drop the word "via" only; the regex is anchored so that tokens merely
    # containing "via" (e.g. "trivial") are kept
    quanteda::tokens_remove("^via$", valuetype = "regex") %>%
    
    # remove stop words
    quanteda::tokens_remove(quanteda::stopwords(source = "snowball")) %>%
    
    # concatenate common bigrams
    quanteda::tokens_compound(pattern = compounds)
  
  if(stem == TRUE){
    # stem
    tok <- tok %>% quanteda::tokens_wordstem()
  }
  
  if(!is.null(ngrams)){
    message(paste0("ngramming to n = ", ngrams, "..."))
    
    # ngram
    tok <- tok %>% quanteda::tokens_ngrams(ngrams)
  }
  
  message("Casting to dfm...")
  
  # send tokens to dfm
  dfm.full <- quanteda::dfm(tok, verbose = TRUE, tolower = TRUE)
  
  # map every document (dfm row) to its user; the dfm row names are the
  # values of `doc_col`, so match against that column
  users <- data$id[match(rownames(dfm.full), as.character(data[[doc_col]]))]
  
  # roll up by user
  dfm.g <- quanteda::dfm_group(dfm.full, groups = users)
  
  # weight counts by inverse document frequency (tf-idf)
  idf <- quanteda::dfm_tfidf(dfm.g, 
                             scheme_tf = "count", 
                             scheme_df = "inverse", 
                             base = 10,
                             force = FALSE)
  
  # total tf-idf weight of each token, largest first
  global_importance <- sort(colSums(idf), decreasing = TRUE)
  
  # keep the top `keep_pct` share of tokens; head() keeps nothing when the
  # cutoff rounds down to zero (1:0 would have kept the first token)
  n_keep <- floor(length(global_importance) * keep_pct)
  keeps <- head(global_importance, n_keep)
  
  # subset to the kept tokens
  idf_sub <- idf[, colnames(idf) %in% names(keeps)]
  
  return(idf_sub)
}

# pivot scaling stuff

# Run the pivot-scaling pipeline on a data frame of documents.
#
# `dat` must have a `lowtext` column holding the document text; the whole
# data frame is passed along as document metadata. The text is processed with
# textProcessor(), filtered with prepDocuments(), turned into a term-document
# matrix with doc_to_tdm() and scaled with scale_text() (pivot = 2, outliers
# constrained). Returns the object produced by scale_text().
parrot_routine <- function(dat){
  # attach the packages the routine relies on
  for (pkg in c("stm", "parrot", "data.table", "quanteda")) {
    require(pkg, character.only = TRUE)
  }
  
  message("Processing... \n")
  cleaned <- textProcessor(
    documents = dat$lowtext,
    metadata = data.frame(dat),
    lowercase = TRUE,
    removestopwords = TRUE,
    removenumbers = TRUE,
    removepunctuation = TRUE,
    stem = TRUE
  )
  
  message("Prepping... \n")
  prepped <- prepDocuments(cleaned$documents, cleaned$vocab, cleaned$meta)
  
  message("Extracting embeddings.... \n")
  term_doc <- doc_to_tdm(prepped)
  
  message("Calculating word scores... \n")
  # embeddings are not supplied to scale_text(); the original note says they
  # have little effect on output and suggests a lower pivot (e.g. 1/2) if used
  word_scaling <- scale_text(
    meta = prepped$meta,
    tdm = term_doc,
    constrain_outliers = TRUE,
    pivot = 2
  )
  
  word_scaling
}

# List the most extreme keywords on each scaled dimension.
#
# Arguments:
#   scores         list with `vocab` (character vector) and the matrices
#                  `pivot_scores` and `word_scores`; column i + 1 of each
#                  matrix is read for dimension i
#   n_dimensions   a single number k (dimensions 1..k are shown) or a vector
#                  of dimension indices
#   n_words        number of keywords listed at each end of a dimension
#   stretch        odd integer exponent applied to the word scores
#   capture_output if TRUE, return the tables instead of printing them
#   pivots_only    if TRUE, list pivot keywords only; otherwise also list the
#                  score-based keywords
#   topic          label appended to the table caption; only used when the
#                  tables are printed through knitr
#
# Returns a named list ("D1", "D2", ...) of data frames when `capture_output`
# is TRUE; otherwise prints one table per dimension and returns NULL
# invisibly.
get_keywords_custom <- function (scores, n_dimensions, n_words = 15, stretch = 3, capture_output = FALSE, 
                                 pivots_only = TRUE, topic) 
{
  if (stretch %% 2 != 1) {
    stop("Please enter odd integer for \"stretch\"")
  }
  
  # seq_len() yields no dimensions for 0 (1:0 would visit 1 and 0)
  dimensions <- if (length(n_dimensions) == 1) {
    seq_len(n_dimensions)
  } else {
    n_dimensions
  }
  
  # per-word weight shared by every dimension; drop = FALSE keeps a matrix
  # when there is only one scaled dimension, which rowSums() requires
  word_weight <- sqrt(rowSums(scores$pivot_scores[, -1, drop = FALSE]^2))
  
  all_keywords <- list()
  for (i in dimensions) {
    general_keywords <- scores$vocab[order(scores$pivot_scores[, i + 1] * word_weight, 
                                           decreasing = TRUE)]
    if (pivots_only) {
      keywords <- data.frame(head(rev(general_keywords), n = n_words), 
                             head(general_keywords, n = n_words))
      names(keywords) <- c("pivots (-)", "(+) pivots")
    }
    else {
      specific_keywords <- scores$vocab[order(scores$word_scores[, i + 1]^(stretch) * word_weight, 
                                              decreasing = TRUE)]
      keywords <- data.frame(head(rev(specific_keywords), n = n_words), 
                             head(rev(general_keywords), n = n_words), 
                             head(general_keywords, n = n_words), 
                             head(specific_keywords, n = n_words))
      names(keywords) <- c("scores (-)", "pivots (-)", 
                           "(+) pivots", "(+) scores")
    }
    if (capture_output) {
      all_keywords[[paste0("D", i)]] <- keywords
    }
    else {
      # fall back to a plain print when knitr is not installed
      if (!requireNamespace("knitr", quietly = TRUE)) {
        cat("\nDimension", i, "keywords\n\n")
        print(keywords, row.names = FALSE)
        cat("\n")
      }
      else {
        print(knitr::kable(keywords, align = "c", format = "pandoc", 
                           caption = paste("Dimension", i, "keywords: ", topic)))
        cat("\n")
      }
    }
  }
  if (capture_output) {
    return(all_keywords)
  }
  invisible(NULL)
}

# Plot the words of a pivot-scaling result on two of its dimensions.
#
# Arguments:
#   scores        list with `word_scores` (matrix; column i + 1 is read for
#                 dimension i), `word_counts`, `vocab` and, when `unstretch`
#                 is used, `importance` and `pivot_scores`; an optional
#                 `color` element supplies one colour group per word
#   x_dimension   dimension shown on the x axis
#   y_dimension   dimension shown on the y axis
#   q_cutoff      only words whose count exceeds this quantile of
#                 `word_counts` are drawn
#   plot_density  if TRUE, arrange the word plot next to density plots of the
#                 two dimensions (via gridExtra) instead of returning it
#   unstretch     if TRUE, rescale the word scores row-wise before plotting
#   color         if TRUE, colour words by `scores$color`; when that element
#                 is missing, words are grouped into 5 k-means clusters
#   subjname      subject named in the plot title
#
# Returns the ggplot object when `plot_density` is FALSE, otherwise the
# result of gridExtra::grid.arrange().
plot_keywords_custom <- function (scores, x_dimension = 1, y_dimension = 2, q_cutoff = 0.9, 
                                  plot_density = FALSE, unstretch = FALSE, color = FALSE,
                                  subjname = "subject"){
  if (unstretch) {
    scores$word_scores <- sweep(scores$word_scores, 1, 
                                sqrt(rowSums((scores$importance[-1] * scores$pivot_scores[, -1])^2)) + 1, `/`)
  }
  word_scores <- data.frame(scores$word_scores)
  word_counts <- scores$word_counts
  above_cutoff <- word_counts > quantile(word_counts, q_cutoff)
  
  # shift from dimension number to column number
  x_dimension <- x_dimension + 1
  y_dimension <- y_dimension + 1
  
  # `&&` because this is a scalar condition
  if (color && !("color" %in% names(scores))) {
    # cluster on up to the first ten dimensions; capping at the number of
    # available columns avoids a subscript error with fewer dimensions
    cluster_cols <- 2:min(11, ncol(scores$word_scores))
    scores$color <- factor(kmeans(scores$word_scores[, cluster_cols, drop = FALSE], 
                                  centers = 5)$cluster)
  }
  
  # the coloured and uncoloured plots differ only in the text layer's colour
  # aesthetic and in which legends are switched off
  if (!color) {
    text_layer <- ggplot2::geom_text(data = word_scores[above_cutoff, ], 
                                     ggplot2::aes(x = word_scores[above_cutoff, x_dimension], 
                                                  y = word_scores[above_cutoff, y_dimension], 
                                                  label = scores$vocab[above_cutoff]))
    legends_off <- ggplot2::guides(size = FALSE)
  }
  else {
    text_layer <- ggplot2::geom_text(data = word_scores[above_cutoff, ], 
                                     ggplot2::aes(x = word_scores[above_cutoff, x_dimension], 
                                                  y = word_scores[above_cutoff, y_dimension], 
                                                  label = scores$vocab[above_cutoff], 
                                                  color = scores$color[above_cutoff]))
    legends_off <- ggplot2::guides(size = FALSE, color = FALSE)
  }
  
  # symmetric axis limits around zero
  x_limit <- max(abs(word_scores[above_cutoff, x_dimension]))
  y_limit <- max(abs(word_scores[above_cutoff, y_dimension]))
  
  g <- ggplot2::ggplot() + 
    text_layer + 
    ggplot2::xlab(paste("Dimension:", x_dimension - 1)) + 
    ggplot2::ylab(paste("Dimension:", y_dimension - 1)) + 
    legends_off + ggplot2::theme_classic() + 
    ggplot2::xlim(-x_limit, x_limit) + 
    ggplot2::ylim(-y_limit, y_limit) +
    ggplot2::ggtitle(paste0("Embedded dimensions in pundits' tweets about ", subjname),
                     subtitle = paste0("Top ", 100*(1-q_cutoff), "% of words shown"))
  
  if (!plot_density) {
    return(g)
  }
  else {
    # word plot on the left (two thirds), the two densities stacked on the right
    gridExtra::grid.arrange(g, ggplot2::ggplot() + 
                              ggplot2::geom_density(ggplot2::aes(x = word_scores[, x_dimension])) +
                              ggplot2::xlab(paste("Dimension:", x_dimension - 1)) + 
                              ggplot2::theme_classic(), 
                            ggplot2::ggplot() + 
                              ggplot2::geom_density(ggplot2::aes(x = word_scores[,  y_dimension])) + 
                              ggplot2::xlab(paste("Dimension",  y_dimension - 1)) + 
                              ggplot2::theme_classic(), 
                            layout_matrix = rbind(c(1, 1, 2), c(1, 1, 3)))
  }
}

# Run a multinomial inverse regression (distrom::dmr) of user-level term
# counts on an outcome variable and project users onto the fitted direction.
#
# Arguments:
#   dat         data frame with one document per row; must contain an `id`
#               column identifying the user each document belongs to
#   outvar      name of the outcome column (a single covariate is assumed
#               when the coefficient table is labelled)
#   text_col    name of the column holding the document text
#   doc_col     name of the column holding the document ids
#   to_compound character vector of additional phrases (words separated by
#               spaces) to compound, on top of the built-in list below
#   weightvar   NULL, or the name of a numeric column; each document's counts
#               are multiplied by log(weight + 1)
#   term_min    minimum share of documents a term must appear in
#   stem        if TRUE (default), stem tokens after compounding
#   ngrams      NULL, or the largest n; n-grams of size 1..n are generated
#   seed        value passed to set.seed()
#
# Returns a list with `mod` (the dmr fit), `preds` (in-sample projection from
# textir::srproj) and `data` (coefficients joined with term frequencies).
text_to_space <- function(dat, outvar,
                          text_col, doc_col,
                          to_compound,
                          weightvar = NULL,
                          term_min = .01, 
                          stem = TRUE,
                          ngrams = NULL,
                          seed = 11111){
  require(tidyverse)
  require(quanteda)
  require(distrom)
  require(textir)
  
  set.seed(seed)
  
  # document ids as text, for matching against the dfm row names (which are
  # taken from `doc_col`)
  doc_ids <- as.character(dat[[doc_col]])
  
  message("Making corpus")
  corp <- quanteda::corpus(dat, text_field = text_col, docid_field = doc_col)
  
  message("Setting compounds...")
  
  # phrase() splits each entry on whitespace, so an entry has to be written
  # with spaces (not underscores) to match a sequence of tokens
  compounds <- quanteda::phrase(
    c("white house", "fox news", "united states",
      "united nations",
      "social media", "climate change", "right wing",
      to_compound)
  )
  
  message("Tokenizing...")
  
  # tokenize/preprocess
  tok <- quanteda::tokens(corp,
                          remove_punct = TRUE,
                          remove_twitter = TRUE, 
                          remove_numbers = TRUE,
                          remove_url = TRUE) %>%
    # drop tokens that contain separator characters
    quanteda::tokens_remove("\\p{Z}", valuetype = "regex") %>%
    # drop the word "via" only; the regex is anchored so that tokens merely
    # containing "via" (e.g. "trivial") are kept
    quanteda::tokens_remove("^via$", valuetype = "regex") %>%
    
    # remove stop words
    quanteda::tokens_remove(quanteda::stopwords(source = "snowball")) %>%
    
    # concatenate common bigrams
    quanteda::tokens_compound(pattern = compounds)
  
  if(stem == TRUE){
    message("Stemming...")
    # stem
    tok <- tok %>% quanteda::tokens_wordstem()
  }
  if(!is.null(ngrams)){
    message(paste0("ngramming to n = ", ngrams, "..."))
    
    # ngram
    tok <- tok %>% quanteda::tokens_ngrams(1:ngrams)
  }
  
  # make full document frequency matrix
  dfm.full <- quanteda::dfm(tok, verbose = TRUE, tolower = TRUE)
  
  # discard very frequent/infrequent terms
  message("Trimming dfm...")
  
  dfm.lim <- quanteda::dfm_trim(
    dfm.full, 
    min_docfreq = term_min, 
    max_docfreq = .5,
    docfreq_type = "prop")
  
  # if weighting, weight
  # right now this assumes you're either weighting by retweet or follower count, 
  # which should be logged
  if(!is.null(weightvar)){
    message("Weighting dfm...")
    for(i in rownames(dfm.lim)){
      weight <- as.numeric(dat[[weightvar]][doc_ids == i])
      row_i <- which(rownames(dfm.lim) == i)
      
      dfm.lim[row_i, ] <- dfm.lim[row_i, ] * log(weight + 1)
    }
  }
  
  message("Prepping for MIR...")
  
  # put back in matrix
  X <- as.matrix(dfm.lim)
  
  # match documents to users to roll up
  users <- dat$id[match(rownames(dfm.lim), doc_ids)]
  
  # roll up by user
  tmp <- by(X, INDICES = users, colSums, simplify = TRUE)
  
  # cast independent variables: term frequencies by user
  X <- do.call(rbind, tmp)
  
  # make outcome variable (first row found for each user)
  Y <- dat[match(rownames(X), dat$id), outvar]
  
  # discard rows with no token usage
  # shouldn't be the case when aggregating by member
  drops <- which(rowSums(X) == 0)
  
  if(length(drops) > 0){
    # Y is a vector for a plain data frame but a one-column table for a
    # tibble; drop rows in either case
    Y <- if (is.null(dim(Y))) Y[-drops] else Y[-drops, , drop = FALSE]
    X <- X[-drops, , drop = FALSE]
  }
  
  message("Modeling...")
  
  # run model
  mod1 <- dmr(covars = Y, counts = X, cl = NULL, verb = 2, gamma = 0, nlambda = 100)
  
  message("Prepping data...")
  
  # store coefs
  c1 <- coef(mod1)
  
  # make coefficient frame (one row per term)
  tmp <- as.data.frame(as.matrix(t(c1)))
  names(tmp) <- c("intercept", "coef")
  tmp$word <- rownames(tmp)
  
  # put term frequencies back on coefficient data frame
  xm <- data.frame(word = colnames(dfm.lim),
                   freq = colSums(dfm.lim))
  xm$word <- as.character(xm$word)
  
  tmp <- tmp %>% left_join(xm, by = "word")
  
  # project in-sample
  insample_proj <- textir::srproj(mod1, X)
  
  message("Done!")
  
  return(list(mod = mod1,
              preds = insample_proj,
              data = tmp))
}


# Fit a ranger random forest on a random training share of the documents and
# predict the held-out documents.
#
# Arguments:
#   dat              data frame (tibble) with one document per row; must
#                    contain a `caucus` column with values "R"/"D" and, when
#                    `use_dates` is TRUE, a `day_relative_1120` column
#   outvar           name of the outcome column
#   text_col         name of the column holding the document text
#   doc_col          name of the column holding the document ids
#   to_compound      character vector of additional phrases (words separated
#                    by spaces) to compound, on top of the built-in list
#   weightvar        NULL, or the name of a numeric column; each document's
#                    counts are multiplied by log(weight + 1)
#   use_dates        if TRUE, append `day_relative_1120` as a feature
#   store_importance passed to ranger's `importance` argument
#   term_min         minimum share of documents a term must appear in
#   trainshare       share of documents used for training
#   stem             if TRUE (default), stem tokens after compounding
#   ngrams           NULL, or the largest n; n-grams of size 1..n are generated
#   seed             value passed to set.seed()
#
# Returns a list with `mod` (the ranger fit) and `oos_predictions` (the test
# rows of `dat` with `preds`, `classpred` and `correct` added).
text_to_pred_traintest <- function(dat, outvar,
                                   text_col, doc_col,
                                   to_compound,
                                   weightvar = NULL,
                                   use_dates = TRUE,
                                   store_importance = "none",
                                   term_min = .01, 
                                   trainshare = .7,
                                   stem = TRUE,
                                   ngrams = NULL,
                                   seed = 11111){
  require(tidyverse)
  require(quanteda)
  require(distrom)
  require(textir)
  
  set.seed(seed)
  
  # define index from the document id column, so the lookups below work for
  # any `doc_col` and not only for a column literally called "index"
  dat$index <- dat[[doc_col]]
  
  message("Making corpus")
  corp <- quanteda::corpus(dat, text_field = text_col, docid_field = "index")
  
  message("Setting compounds...")
  
  # phrase() splits each entry on whitespace, so an entry has to be written
  # with spaces (not underscores) to match a sequence of tokens
  compounds <- quanteda::phrase(
    c("white house", "fox news", "united states",
      "united nations",
      "social media", "climate change", "right wing",
      to_compound)
  )
  
  message("Tokenizing...")
  
  # tokenize/preprocess
  tok <- quanteda::tokens(corp,
                          remove_punct = TRUE,
                          remove_twitter = TRUE, 
                          remove_numbers = TRUE,
                          remove_url = TRUE) %>%
    # drop tokens that contain separator characters
    quanteda::tokens_remove("\\p{Z}", valuetype = "regex") %>%
    # drop the word "via" only; the regex is anchored so that tokens merely
    # containing "via" (e.g. "trivial") are kept
    quanteda::tokens_remove("^via$", valuetype = "regex") %>%
    
    # remove stop words
    quanteda::tokens_remove(quanteda::stopwords(source = "snowball")) %>%
    
    # concatenate common bigrams
    quanteda::tokens_compound(pattern = compounds)
  
  if(stem == TRUE){
    message("Stemming...")
    # stem
    tok <- tok %>% quanteda::tokens_wordstem()
  }
  if(!is.null(ngrams)){
    message(paste0("ngramming to n = ", ngrams, "..."))
    
    # ngram
    tok <- tok %>% quanteda::tokens_ngrams(1:ngrams)
  }
  
  # make full document frequency matrix
  dfm.full <- quanteda::dfm(tok, verbose = TRUE, tolower = TRUE)
  
  # discard very frequent/infrequent terms
  message("Trimming dfm...")
  
  dfm.lim <- quanteda::dfm_trim(
    dfm.full, 
    min_docfreq = term_min, 
    max_docfreq = .5,
    docfreq_type = "prop")
  
  # if weighting, weight
  # right now this assumes you're either weighting by retweet or follower count, 
  # which should be logged
  if(!is.null(weightvar)){
    message("Weighting dfm...")
    for(i in rownames(dfm.lim)){
      weight <- dat[which(as.character(dat$index) == i), which(names(dat) == weightvar)] %>% as.numeric()
      row_i <- which(rownames(dfm.lim) == i)
      
      dfm.lim[row_i, ] <- dfm.lim[row_i, ] * log(weight + 1)
    }
  }
  
  message("Prepping for MIR...")
  
  # put back in matrix
  X <- as.matrix(dfm.lim)
  
  # row of `dat` behind every document in X
  doc_rows <- match(rownames(X), dat$index)
  
  # find empty docs from the text features alone, BEFORE the date column is
  # appended: the relative day would otherwise be added into the row sums
  drops <- which(rowSums(X) == 0)
  
  if(use_dates == TRUE){
    # make days before/after jan 1 2020 vector and append to text features
    dates <- dat$day_relative_1120[doc_rows]
    
    X <- cbind(X, dates)
  }
  
  # make outcome vector
  Y <- dat[doc_rows, outvar] %>% pull(outvar)
  
  # discard empty docs
  if(length(drops) > 0){
    Y <- Y[-as.numeric(drops)]
    X <- X[-as.numeric(drops), , drop = FALSE]
  }
  
  # make train/test splits
  which_train <- sample(seq_along(Y), floor(length(Y) * trainshare), replace = FALSE)
  
  X_train <- X[which_train, ]
  Y_train <- Y[which_train]
  
  X_test <- X[-which_train, ]
  Y_test <- Y[-which_train]
  
  message("Modeling...")
  
  # bind training outcome/features
  trainbind <- cbind(Y_train, X_train)
  
  # model
  rf1 <- ranger::ranger(data = trainbind,
                        dependent.variable.name = "Y_train",
                        classification = TRUE,
                        probability = TRUE,
                        importance = store_importance,
                        verbose = TRUE)
  
  # generate predictions; the filtered rows are assumed to come out in the
  # same order as the rows of X_test (both follow the row order of `dat`)
  pt <- dat %>% filter(index %in% rownames(X_test)) %>%
    # take first column of predicted probabilities
    mutate(preds = predict(rf1, X_test)$predictions[, 1]) %>%
    mutate(classpred = ifelse(preds >= .5, "R", "D")) %>%
    mutate(correct = (classpred == caucus))
  
  # if the classes got flipped in RF prediction, flip back
  if(mean(pt$correct) < .5){
    pt <- pt %>% mutate(preds = 1 - preds) %>%
      mutate(classpred = ifelse(preds >= .5, "R", "D")) %>%
      mutate(correct = (classpred == caucus))
  }
  
  message("Done!")
  
  return(list(mod = rf1,
              oos_predictions = pt))
}

# Draw a horizontal bar chart of a random forest's variable importance.
#
# Arguments:
#   rf_mod         fitted model exposing `variable.importance` and either
#                  `r.squared` or `prediction.error`
#   out            plot title
#   xvars          label for the axis listing the variables
#   sub            plot subtitle
#   classification FALSE: caption reports the OOB R-squared;
#                  TRUE: caption reports the OOB prediction error
#   shrink         NULL, or a threshold; only variables with importance above
#                  it are drawn
#
# Returns the ggplot object. Relies on theme_dfp(), which is defined outside
# this function.
rf_to_viz <- function(rf_mod, out, xvars, sub, classification = FALSE, shrink = NULL){
  
  # one row per variable: its importance and its name
  importance_tbl <- data.frame(rf_mod$variable.importance)
  names(importance_tbl) <- "importance"
  importance_tbl$var <- rownames(importance_tbl)
  
  if(!is.null(shrink)){
    importance_tbl <- filter(importance_tbl, importance > shrink)
  }
  
  # the two model types share everything except the caption text
  draw_importance <- function(caption_text){
    ggplot(importance_tbl,
           aes(x = reorder(var, importance), 
               y = importance)) + 
      geom_bar(stat = "identity", 
               position = "dodge") + 
      coord_flip() +
      ylab("Variable importance (permutation)") +
      xlab(xvars) +
      labs(title = paste0(out),
           subtitle = sub,
           caption = caption_text) +
      guides(fill = F) +
      theme_dfp()
  }
  
  if(classification == FALSE){
    vimp_lpc_plot <- draw_importance(
      paste0("OOB R-Squared: ", round(rf_mod$r.squared, 3))
    )
  }
  
  if(classification == TRUE){
    vimp_lpc_plot <- draw_importance(
      paste0("OOB Prediction Error: ", round(rf_mod$prediction.error, 3))
    )
  }
  
  return(vimp_lpc_plot)
}


# Repeated random train/test evaluation of a ranger random forest that
# predicts an outcome from text features.
#
# Arguments:
#   dat              data (tibble) with one document per row; must contain a
#                    `caucus` column with values "R"/"D"
#   outvar           character string: outcome column
#   text_col         character string: column with the text
#   doc_col          character string: column with the document id
#   to_compound      character vector of terms to compound (words separated
#                    by spaces), on top of the built-in list below
#   stem             logical: whether to stem (default TRUE)
#   ngrams           numeric vector of n-gram sizes to tokenize to (default
#                    NULL, but we use it in the final model)
#   date_var         character string: date column, if including dates
#                    (default NULL, but we use it in the final model)
#   volume_var       character string: volume column, if including (default
#                    NULL, we don't use it in the final model)
#   weightvar        character string: weight column, if including (default
#                    NULL, we don't use it in the final model)
#   store_importance character string: type of variable importance to store
#                    in ranger (default "none")
#   term_min         minimum share of docs a term has to appear in to be
#                    included in the dfm (default 1%)
#   trainshare       proportion of docs included in the training set
#                    (default 70%)
#   n_folds          number of folds (default 15 -- five runs on three cores)
#   seed             seed; the train/test splits and the per-fold ranger
#                    seeds are drawn from it before the parallel loop starts
#
# Returns a list with `fit_results` (accuracy, ROC AUC and PR AUC per fold),
# `preds` (stacked test-set rows with predictions) and `importance` (list of
# per-fold variable importance, or NULL).
text_to_pred_multifolds <- function(dat, outvar,
                                    text_col, doc_col,
                                    to_compound,
                                    stem = TRUE,
                                    ngrams = NULL,
                                    date_var = NULL,
                                    volume_var = NULL,
                                    weightvar = NULL,
                                    store_importance = "none",
                                    term_min = .01, 
                                    trainshare = .7,
                                    n_folds = 15,
                                    seed = 11111){
  require(tidyverse)
  require(quanteda)
  require(distrom)
  require(textir)
  
  require(foreach)
  require(doMC)
  # leave one core free, but never register fewer than one worker
  # (detectCores() can return 1 or NA)
  n_cores <- max(1, parallel::detectCores() - 1, na.rm = TRUE)
  registerDoMC(cores = n_cores)
  
  set.seed(seed)
  
  # define index
  dat$index <- dat %>% pull(doc_col)
  
  message("Making corpus")
  corp <- quanteda::corpus(dat, text_field = text_col, docid_field = "index")
  
  message("Setting compounds...")
  
  compounds <- quanteda::phrase(
    c("white house", "fox news", "united states",
      "united nations",
      "social media", "climate change", "right wing",
      to_compound)
  )
  
  message("Tokenizing...")
  
  # tokenize/preprocess/remove extraneous twitter stuff
  tok <- quanteda::tokens(corp,
                          remove_punct = TRUE,
                          remove_twitter = TRUE, 
                          remove_numbers = TRUE,
                          remove_url = TRUE) %>%
    # drop tokens that contain separator characters
    quanteda::tokens_remove("\\p{Z}", valuetype = "regex") %>%
    # drop the word "via" only; the regex is anchored so that tokens merely
    # containing "via" (e.g. "trivial") are kept
    quanteda::tokens_remove("^via$", valuetype = "regex") %>%
    
    # remove stop words
    quanteda::tokens_remove(quanteda::stopwords(source = "snowball")) %>%
    
    # concatenate common bigrams
    quanteda::tokens_compound(pattern = compounds)
  
  if(stem == TRUE){
    message("Stemming...")
    tok <- tok %>% quanteda::tokens_wordstem()
  }
  if(!is.null(ngrams)){
    message(paste0("ngramming to n = ", ngrams, "..."))
    
    # ngram
    tok <- tok %>% quanteda::tokens_ngrams(ngrams)
  }
  
  # make full document frequency matrix
  dfm.full <- quanteda::dfm(tok, verbose = TRUE, tolower = TRUE)
  
  # discard very frequent/infrequent terms
  message("Trimming dfm...")
  
  dfm.lim <- quanteda::dfm_trim(
    dfm.full, 
    min_docfreq = term_min, 
    max_docfreq = .5,
    docfreq_type = "prop")
  
  message(paste0("Retained ", ncol(dfm.lim), " tokens for analysis..."))
  
  if(!is.null(weightvar)){
    message("Weighting dfm...")
    for(i in rownames(dfm.lim)){
      weight <- dat[which(as.character(dat$index) == i), which(names(dat) == weightvar)] %>% as.numeric()
      row_i <- which(rownames(dfm.lim) == i)
      
      dfm.lim[row_i, ] <- dfm.lim[row_i, ] * log(weight + 1)
    }
  }
  
  message("Prepping IV and DV...")
  
  # put back in matrix
  X <- as.matrix(dfm.lim)
  
  # row of `dat` behind every document in X
  doc_rows <- match(rownames(X), dat$index)
  
  # find empty docs from the text features alone, BEFORE dates/volume are
  # appended: those covariates would otherwise be added into the row sums
  drops <- which(rowSums(X) == 0)
  
  if(!is.null(date_var)){
    # make date variable and append to text features
    dates <- dat[doc_rows, date_var]
    
    X <- cbind(X, dates)
  }
  
  if(!is.null(volume_var)){
    # make volume variable and append to text features
    volume <- dat[doc_rows, volume_var]
    
    X <- cbind(X, volume)
  }
  
  # make outcome vector
  Y <- dat[doc_rows, outvar] %>% pull(outvar)
  
  # discard empty docs
  if(length(drops) > 0){
    Y <- Y[-as.numeric(drops)]
    X <- X[-as.numeric(drops), , drop = FALSE]
  }
  
  # Draw every fold's training rows and ranger seed here, in the parent
  # process. Random draws made inside the forked workers are not governed by
  # set.seed() above, so drawing them there would make runs irreproducible.
  n_docs <- length(Y)
  n_train <- floor(n_docs * trainshare)
  train_rows <- lapply(seq_len(n_folds), function(fold){
    sample.int(n_docs, n_train, replace = FALSE)
  })
  fold_seeds <- sample.int(.Machine$integer.max, n_folds)
  
  # fit and evaluate one model per fold
  pred_routine <- foreach::foreach(i = seq_len(n_folds)) %dopar% {
    
    message(paste0("Splitting train and test sets for fold number", i,"..."))
    
    which_train <- train_rows[[i]]
    
    # define IVs and DV for training and test sets
    X_train <- X[which_train, ]
    Y_train <- Y[which_train]
    
    X_test <- X[-which_train, ]
    Y_test <- Y[-which_train]
    
    # bind training outcome/features
    trainbind <- cbind(Y_train, X_train)
    
    message(paste0("Modeling fold number", i,"..."))
    
    # model
    rf1 <- ranger::ranger(data = trainbind,
                          dependent.variable.name = "Y_train",
                          classification = TRUE,
                          probability = TRUE,
                          importance = store_importance,
                          seed = fold_seeds[i],
                          verbose = FALSE)
    
    message(paste0("Predicting test set for fold number ", i,"..."))
    
    # generate predictions; the filtered rows are assumed to come out in the
    # same order as the rows of X_test (both follow the row order of `dat`)
    pt <- dat %>% filter(index %in% rownames(X_test)) %>%
      
      # take first column of predicted probabilities
      mutate(preds = predict(rf1, X_test)$predictions[, 1]) %>%
      mutate(classpred = ifelse(preds >= .5, "R", "D")) %>%
      mutate(correct = (classpred == caucus),
             fold = i)
    
    # if the classes got flipped in RF prediction, flip back
    if(mean(pt$correct) < .5){
      pt <- pt %>% mutate(preds = 1 - preds) %>%
        mutate(classpred = ifelse(preds >= .5, "R", "D")) %>%
        mutate(correct = (classpred == caucus))
    }
    
    # take measures of fit: overall accuracy, area under roc curve, and area
    # under precision-recall curve for current fold.
    # PRROC takes the classifier scores of all points in `scores.class0` and
    # the 0/1 membership in the positive class in `weights.class0`.
    is_positive <- as.numeric(pt$caucus == "R")
    
    aucroc <- PRROC::roc.curve(scores.class0 = pt$preds,
                               weights.class0 = is_positive)$auc
    
    auprc <- PRROC::pr.curve(scores.class0 = pt$preds,
                             weights.class0 = is_positive)$auc.integral
    
    outmat <- data.frame(acc = mean(pt$correct),
                         roc = aucroc,
                         prc = auprc,
                         fold = i)
    
    # if using variable importance, pull that out too
    if(store_importance != "none"){
      imp <- rf1$variable.importance
    }else{
      imp <- NULL
    }
    
    list(fit_results = outmat,
         test_set_data = pt,
         importance = imp)
  }
  
  # stack measures of fit across folds
  outmat <- bind_rows(lapply(pred_routine, function(x){
    return(x$fit_results)
  }))
  
  test_sets <- bind_rows(lapply(pred_routine, function(x){
    return(x$test_set_data)
  }))
  
  # if storing importance, pull that out
  if(store_importance != "none"){
    imp <- lapply(pred_routine, function(x){
      x$importance
    })
  }else{
    imp <- NULL
  }
  
  message("Done!")
  return(list(fit_results = outmat,
              preds = test_sets,
              importance = imp))
}

# Repeated random train/test evaluation of a multinomial inverse regression
# (distrom::dmr): fit on the training documents, project the test documents
# with textir::srproj and classify them by the sign of the projection.
#
# Arguments:
#   dat         data (tibble) with one document per row; must contain a
#               `caucus` column with values "R"/"D"
#   outvar      character string: outcome column (numeric; test documents
#               with a projection >= 0 are classified as "R")
#   text_col    character string: column with the text
#   doc_col     character string: column with the document id
#   to_compound character vector of terms to compound (words separated by
#               spaces), on top of the built-in list below
#   stem        logical: whether to stem (default TRUE)
#   ngrams      NULL, or the n-gram sizes passed to tokens_ngrams()
#   date_var    NULL, or the name of a date column added as a covariate
#   volume_var  NULL, or the name of a volume column added as a covariate
#   weightvar   NULL, or the name of a numeric column; each document's counts
#               are multiplied by log(weight + 1)
#   term_min    minimum share of docs a term has to appear in
#   trainshare  proportion of docs included in the training set
#   n_folds     number of folds
#   seed        seed; the train/test splits are drawn from it before the
#               parallel loop starts
#
# Returns a list with `fit_results` (accuracy per fold) and `preds` (stacked
# test-set rows with projections and predicted class).
MIR_to_pred_multifolds <- function(dat, outvar,
                                   text_col, doc_col,
                                   to_compound,
                                   stem = TRUE,
                                   ngrams = NULL,
                                   date_var = NULL,
                                   volume_var = NULL,
                                   weightvar = NULL,
                                   term_min = .01, 
                                   trainshare = .7,
                                   n_folds = 15,
                                   seed = 11111){
  require(tidyverse)
  require(quanteda)
  require(distrom)
  require(textir)
  
  require(foreach)
  require(doMC)
  # leave one core free, but never register fewer than one worker
  # (detectCores() can return 1 or NA)
  n_cores <- max(1, parallel::detectCores() - 1, na.rm = TRUE)
  registerDoMC(cores = n_cores)
  
  set.seed(seed)
  
  # define index
  dat$index <- dat %>% pull(doc_col) %>% as.character()
  
  message("Making corpus")
  corp <- quanteda::corpus(dat, text_field = text_col, docid_field = "index")
  
  message("Setting compounds...")
  
  # phrase() splits each entry on whitespace, so an entry has to be written
  # with spaces (not underscores) to match a sequence of tokens
  compounds <- quanteda::phrase(
    c("white house", "fox news", "united states",
      "united nations",
      "social media", "climate change", "right wing",
      to_compound)
  )
  
  message("Tokenizing...")
  
  # tokenize/preprocess
  tok <- quanteda::tokens(corp,
                          remove_punct = TRUE,
                          remove_twitter = TRUE, 
                          remove_numbers = TRUE,
                          remove_url = TRUE) %>%
    # drop tokens that contain separator characters
    quanteda::tokens_remove("\\p{Z}", valuetype = "regex") %>%
    # drop the word "via" only; the regex is anchored so that tokens merely
    # containing "via" (e.g. "trivial") are kept
    quanteda::tokens_remove("^via$", valuetype = "regex") %>%
    
    # remove stop words
    quanteda::tokens_remove(quanteda::stopwords(source = "snowball")) %>%
    
    # concatenate common bigrams
    quanteda::tokens_compound(pattern = compounds)
  
  if(stem == TRUE){
    message("Stemming...")
    # stem
    tok <- tok %>% quanteda::tokens_wordstem()
  }
  if(!is.null(ngrams)){
    message(paste0("ngramming to n = ", ngrams, "..."))
    
    # ngram
    tok <- tok %>% quanteda::tokens_ngrams(ngrams)
  }
  
  # make full document frequency matrix
  dfm.full <- quanteda::dfm(tok, verbose = TRUE, tolower = TRUE)
  
  # discard very frequent/infrequent terms
  message("Trimming dfm...")
  
  dfm.lim <- quanteda::dfm_trim(
    dfm.full, 
    min_docfreq = term_min, 
    max_docfreq = .5,
    docfreq_type = "prop")
  
  message(paste0("Retained ", ncol(dfm.lim), " tokens for analysis..."))
  
  if(!is.null(weightvar)){
    message("Weighting dfm...")
    for(i in rownames(dfm.lim)){
      weight <- dat[which(as.character(dat$index) == i), which(names(dat) == weightvar)] %>% as.numeric()
      row_i <- which(rownames(dfm.lim) == i)
      
      dfm.lim[row_i, ] <- dfm.lim[row_i, ] * log(weight + 1)
    }
  }
  
  message("Prepping IV and DV...")
  
  # put back in matrix
  X <- as.matrix(dfm.lim)
  
  # discard empty docs
  drops <- which(rowSums(X) == 0)
  
  # make outcome vector
  Y <- dat[match(rownames(X), dat$index), outvar] %>% pull(outvar)
  
  if(length(drops) > 0){
    Y <- Y[-as.numeric(drops)]
    X <- X[-as.numeric(drops), , drop = FALSE]
  }
  
  # optional extra covariates, aligned with the rows of X; they do not change
  # between folds, so look them up once
  doc_rows <- match(rownames(X), dat$index)
  dates <- if(!is.null(date_var)) dat[doc_rows, date_var] %>% pull(1) else NULL
  volume <- if(!is.null(volume_var)) dat[doc_rows, volume_var] %>% pull(1) else NULL
  
  # Draw every fold's training rows here, in the parent process. Random draws
  # made inside the forked workers are not governed by set.seed() above, so
  # drawing them there would make runs irreproducible.
  n_docs <- length(Y)
  n_train <- floor(n_docs * trainshare)
  train_rows <- lapply(seq_len(n_folds), function(fold){
    sample.int(n_docs, n_train, replace = FALSE)
  })
  
  # fit and evaluate one model per fold
  pred_routine <- foreach::foreach(i = seq_len(n_folds)) %dopar% {
    
    message(paste0("Splitting train and test sets for fold number", i,"..."))
    
    which_train <- train_rows[[i]]
    
    X_train <- X[which_train, ]
    X_test <- X[-which_train, ]
    
    # Covariate matrix for dmr. The outcome column is always named "Y_train",
    # also when no date/volume covariate is added, because the classification
    # step below reads the projection column of that name.
    covars_train <- cbind(Y_train = Y[which_train])
    
    if(!is.null(date_var)){
      covars_train <- cbind(covars_train, dates_train = dates[which_train])
    }
    
    if(!is.null(volume_var)){
      covars_train <- cbind(covars_train, volume_train = volume[which_train])
    }
    
    message(paste0("Modeling fold number", i,"..."))
    
    message("Modeling...")
    
    # run model
    mod1 <- dmr(covars = covars_train, counts = X_train, cl = NULL, verb = 2, gamma = 0, nlambda = 100)
    
    message("Prepping data...")
    
    # project out of sample; one row per test document
    # NOTE(review): assumes srproj() names its projection columns after the
    # covariate columns, as the original date/volume path relied on -- confirm
    outsample_proj <- data.frame(textir::srproj(mod1, X_test))
    outsample_proj$index <- rownames(outsample_proj)
    
    message(paste0("Predicting test set for fold number ", i,"..."))
    
    # generate predictions
    pt <- dat %>% 
      inner_join(outsample_proj, by = "index") %>%
      
      # classify by the sign of the projection onto the outcome direction;
      # .data$ makes sure the joined projection column is used and never a
      # same-named variable from the surrounding environment
      mutate(classpred = ifelse(.data$Y_train >= 0, "R", "D")) %>%
      mutate(correct = (classpred == caucus),
             fold = i)
    
    outmat <- data.frame(acc = mean(pt$correct),
                         fold = i)
    
    list(fit_results = outmat,
         test_set_data = pt)
  }
  
  # stack accuracy and test-set predictions across folds
  outmat <- bind_rows(lapply(pred_routine, function(x){
    return(x$fit_results)
  }))
  
  test_sets <- bind_rows(lapply(pred_routine, function(x){
    return(x$test_set_data)
  }))
  
  message("Done!")
  return(list(fit_results = outmat,
              preds = test_sets))
}


# Infer a diffusion network from when each user first used each term.
#
# Steps, in order: build a corpus from `dat`; tokenize, clean and compound the
# text; cast to a dfm and trim it by document frequency; find, per user and
# term, the earliest timestamp at which the term appears; express those
# timestamps as seconds since 2020-01-01 00:00:00 UTC; run netinf() on the
# resulting cascades.
#
# Arguments:
#   dat          Data frame of documents. Besides `text_col` and `doc_col`,
#                the body reads the columns `id` (user), `index` (matched
#                against the dfm document names) and `created_at`.
#                NOTE(review): `created_at` is presumably POSIXct or epoch
#                seconds; it is coerced with as.numeric() -- confirm.
#   text_col     Name of the text column, passed to quanteda::corpus().
#   doc_col      Name of the document-id column, passed to quanteda::corpus().
#   to_compound  Extra multi-word phrases to join into single tokens.
#   ngrams       If not NULL, passed to quanteda::tokens_ngrams().
#   stem         If TRUE, tokens are stemmed.
#   term_min     Minimum proportion of documents a term must appear in.
#   pcut         p-value cutoff passed to netinf().
#   seed         Passed to set.seed() before anything else runs.
#
# Returns a list with `output` (the netinf result), `edgelist` (frequency
# table of the first column of `output`, with columns id/outdegree), `casc`
# (the cascade object), `n` (number of rows of `dat` per `id`) and `dims`
# (dimensions of the trimmed dfm).
cascade.seconds <- function(dat, 
                            text_col, doc_col, 
                            to_compound,
                            ngrams = NULL,
                            stem = FALSE,
                            term_min = .01,
                            pcut = .01, 
                            seed = 12345){
  require(NetworkInference)
  require(quanteda)
  require(data.table)
  require(tidyverse)
  
  set.seed(seed)
  
  # get count of tweets per user
  ntweets <- dat %>% 
    group_by(id) %>% 
    summarise(n = n())
  
  message("Making corpus")
  corp <- quanteda::corpus(dat, text_field = text_col, docid_field = doc_col)
  
  message("Setting compounds...")
  
  compounds <- quanteda::phrase(
    c("white house","fox news","united states",
      "united_nations",
      "social media","climate change","right wing",
      to_compound)
  )
  
  message("Tokenizing...")
  
  # tokenize/preprocess
  tok <- quanteda::tokens(corp,
                          remove_punct = TRUE,
                          remove_twitter = TRUE, 
                          remove_numbers = TRUE,
                          remove_url = TRUE) %>%
    quanteda::tokens_remove("\\p{Z}", valuetype = "regex") %>%
    quanteda::tokens_remove("via", valuetype = "regex") %>%
    
    # remove stop words
    quanteda::tokens_remove(quanteda::stopwords(source = "snowball")) %>%
    
    # concatenate common bigrams
    quanteda::tokens_compound(pattern = compounds)
  
  if(stem == TRUE){
    message("Stemming...")
    tok <- tok %>% quanteda::tokens_wordstem()
  }
  if(!is.null(ngrams)){
    message(paste0("ngramming to n = ", ngrams, "..."))
    tok <- tok %>% quanteda::tokens_ngrams(ngrams)
  }
  
  # make full document frequency matrix
  dfm.full <- quanteda::dfm(tok, verbose = TRUE, tolower = TRUE)
  
  # discard very frequent/infrequent terms
  message("Trimming dfm...")
  
  dfm.lim <- quanteda::dfm_trim(
    dfm.full, 
    min_docfreq = term_min, 
    max_docfreq = .5,
    docfreq_type = "prop")
  
  # put in matrix and store dimensions
  dfm_mat <- as.matrix(dfm.lim)
  finaldims <- dim(dfm_mat)
  
  if(ncol(dfm_mat) == 0){
    stop("No terms survive dfm trimming; cannot build cascades.", call. = FALSE)
  }
  
  # timestamp of every document, as plain numbers
  dts <- as.numeric(dat$created_at[match(rownames(dfm_mat), dat$index)])
  
  # documents x terms matrix: the document's timestamp where the term occurs,
  # NA where it does not. Kept numeric throughout so that the minimum below is
  # a numeric minimum rather than a comparison of formatted strings.
  ts_mat <- matrix(dts,
                   nrow = nrow(dfm_mat), ncol = ncol(dfm_mat),
                   dimnames = dimnames(dfm_mat))
  ts_mat[dfm_mat == 0] <- NA
  
  # remove terms that have no timestamp at all
  ts_mat <- ts_mat[, colSums(!is.na(ts_mat)) > 0, drop = FALSE]
  
  if(ncol(ts_mat) == 0){
    stop("No term has a usable timestamp; cannot build cascades.", call. = FALSE)
  }
  
  # user behind every document
  # NOTE(review): `corp$documents$id` relies on the corpus object exposing a
  # `documents` data frame (quanteda 1.x layout, presumably) -- confirm.
  users <- as.character(corp$documents$id)
  if(length(users) != nrow(ts_mat)){
    stop("Could not read one `id` per document from the corpus.", call. = FALSE)
  }
  user_ids <- unique(users)
  
  # FIND FIRST USAGE OF TERM BY EACH MoC
  # (one column of earliest timestamps per term, NA if the user never used it)
  first_use <- vapply(
    seq_len(ncol(ts_mat)),
    function(j){
      stamped <- !is.na(ts_mat[, j])
      earliest <- tapply(ts_mat[stamped, j], users[stamped], min)
      as.numeric(as.vector(earliest)[match(user_ids, names(earliest))])
    },
    numeric(length(user_ids))
  )
  
  # vapply() returns a bare vector when there is a single user, so rebuild the
  # users x terms shape explicitly
  DTM <- as.data.frame(
    matrix(first_use,
           nrow = length(user_ids), ncol = ncol(ts_mat),
           dimnames = list(user_ids, colnames(ts_mat)))
  )
  
  # include only terms used by at least 3 members (netinf tracks diffusion of tokens,
  # which can't diffuse if used by only 1 member; 3+ is a slightly more stringent indicator of importance).
  # Count members with a first-use time; counting unique values would also
  # count NA as a "member".
  n_adopters <- vapply(DTM, function(col) sum(!is.na(col)), integer(1))
  DTM <- DTM[, n_adopters >= 3, drop = FALSE]
  
  # transform to seconds since Jan 1 2020 midnight UTC. Done by numeric
  # subtraction so the result does not depend on the session time zone.
  origin_secs <- as.numeric(lubridate::as_datetime("2020-01-01 00:00:00 UTC"))
  DTM[] <- lapply(DTM, function(x) x - origin_secs)
  
  # make sure there are no rows with all NA's, remove if so
  DTM <- DTM[rowSums(is.na(DTM)) != ncol(DTM), , drop = FALSE]
  
  # create vector of row names
  node_names <- unique(row.names(DTM))
  
  # Calculate cascade of tokens
  cascades <- as_cascade_wide(DTM, node_names = node_names)
  
  # run algorithm
  result <- netinf(cascades, trans_mod = "exponential",  p_value_cutoff = pcut)
  
  # get edge list
  edges <- as.data.frame(table(result[,1]))
  names(edges) <- c("id","outdegree")
  
  # return objects of interest
  return(list(output = result, edgelist = edges, 
              casc = cascades, n = ntweets, 
              dims = finaldims))
}

# Build a users x terms table of the time each user first used each term.
#
# Steps, in order: build a corpus from `dat`; tokenize, clean and compound the
# text; cast to a dfm and trim it by document frequency; find, per user and
# term, the earliest timestamp at which the term appears; keep only terms
# first-used by at least 3 users.
#
# Arguments:
#   dat          Data frame of documents. Besides `text_col` and `doc_col`,
#                the body reads the columns `id` (user), `index` (matched
#                against the dfm document names) and `created_at`.
#                NOTE(review): `created_at` is presumably POSIXct or epoch
#                seconds; it is coerced with as.numeric() -- confirm.
#   text_col     Name of the text column, passed to quanteda::corpus().
#   doc_col      Name of the document-id column, passed to quanteda::corpus().
#   to_compound  Extra multi-word phrases to join into single tokens.
#   ngrams       If not NULL, passed to quanteda::tokens_ngrams().
#   stem         If TRUE, tokens are stemmed.
#   term_min     Minimum proportion of documents a term must appear in.
#   pcut         Not used by this function; kept so the signature stays
#                compatible with existing callers.
#   seed         Passed to set.seed() before anything else runs.
#
# Returns a data frame with one row per user (row names are the user ids) and
# one date-time column per retained term; NA where the user never used it.
cascade.first.use <- function(dat, 
                              text_col, doc_col, 
                              to_compound,
                              ngrams = NULL,
                              stem = FALSE,
                              term_min = .01,
                              pcut = .01, 
                              seed = 12345){
  require(NetworkInference)
  require(quanteda)
  require(data.table)
  require(tidyverse)
  
  set.seed(seed)
  
  message("Making corpus")
  corp <- quanteda::corpus(dat, text_field = text_col, docid_field = doc_col)
  
  message("Setting compounds...")
  
  compounds <- quanteda::phrase(
    c("white house","fox news","united states",
      "united_nations",
      "social media","climate change","right wing",
      to_compound)
  )
  
  message("Tokenizing...")
  
  # tokenize/preprocess
  tok <- quanteda::tokens(corp,
                          remove_punct = TRUE,
                          remove_twitter = TRUE, 
                          remove_numbers = TRUE,
                          remove_url = TRUE) %>%
    quanteda::tokens_remove("\\p{Z}", valuetype = "regex") %>%
    quanteda::tokens_remove("via", valuetype = "regex") %>%
    
    # remove stop words
    quanteda::tokens_remove(quanteda::stopwords(source = "snowball")) %>%
    
    # concatenate common bigrams
    quanteda::tokens_compound(pattern = compounds)
  
  if(stem == TRUE){
    message("Stemming...")
    tok <- tok %>% quanteda::tokens_wordstem()
  }
  if(!is.null(ngrams)){
    message(paste0("ngramming to n = ", ngrams, "..."))
    tok <- tok %>% quanteda::tokens_ngrams(ngrams)
  }
  
  # make full document frequency matrix
  dfm.full <- quanteda::dfm(tok, verbose = TRUE, tolower = TRUE)
  
  # discard very frequent/infrequent terms
  message("Trimming dfm...")
  
  dfm.lim <- quanteda::dfm_trim(
    dfm.full, 
    min_docfreq = term_min, 
    max_docfreq = .5,
    docfreq_type = "prop")
  
  # put in matrix
  dfm_mat <- as.matrix(dfm.lim)
  
  if(ncol(dfm_mat) == 0){
    stop("No terms survive dfm trimming; cannot build first-use table.", call. = FALSE)
  }
  
  # timestamp of every document, as plain numbers
  dts <- as.numeric(dat$created_at[match(rownames(dfm_mat), dat$index)])
  
  # documents x terms matrix: the document's timestamp where the term occurs,
  # NA where it does not. Kept numeric throughout so that the minimum below is
  # a numeric minimum rather than a comparison of formatted strings.
  ts_mat <- matrix(dts,
                   nrow = nrow(dfm_mat), ncol = ncol(dfm_mat),
                   dimnames = dimnames(dfm_mat))
  ts_mat[dfm_mat == 0] <- NA
  
  # remove terms that have no timestamp at all
  ts_mat <- ts_mat[, colSums(!is.na(ts_mat)) > 0, drop = FALSE]
  
  if(ncol(ts_mat) == 0){
    stop("No term has a usable timestamp; cannot build first-use table.", call. = FALSE)
  }
  
  # user behind every document
  # NOTE(review): `corp$documents$id` relies on the corpus object exposing a
  # `documents` data frame (quanteda 1.x layout, presumably) -- confirm.
  users <- as.character(corp$documents$id)
  if(length(users) != nrow(ts_mat)){
    stop("Could not read one `id` per document from the corpus.", call. = FALSE)
  }
  user_ids <- unique(users)
  
  # FIND FIRST USAGE OF TERM BY EACH MoC
  # (one column of earliest timestamps per term, NA if the user never used it)
  first_use <- vapply(
    seq_len(ncol(ts_mat)),
    function(j){
      stamped <- !is.na(ts_mat[, j])
      earliest <- tapply(ts_mat[stamped, j], users[stamped], min)
      as.numeric(as.vector(earliest)[match(user_ids, names(earliest))])
    },
    numeric(length(user_ids))
  )
  
  # vapply() returns a bare vector when there is a single user, so rebuild the
  # users x terms shape explicitly
  DTM <- as.data.frame(
    matrix(first_use,
           nrow = length(user_ids), ncol = ncol(ts_mat),
           dimnames = list(user_ids, colnames(ts_mat)))
  )
  
  # include only terms used by at least 3 members (netinf tracks diffusion of tokens,
  # which can't diffuse if used by only 1 member; 3+ is a slightly more stringent indicator of importance).
  # Count members with a first-use time; counting unique values would also
  # count NA as a "member".
  n_adopters <- vapply(DTM, function(col) sum(!is.na(col)), integer(1))
  DTM <- DTM[, n_adopters >= 3, drop = FALSE]
  
  # change class from numeric to date-time
  DTM[] <- lapply(DTM, function(x) lubridate::as_datetime(x))
  
  return(DTM)
}

# Assemble the data frame behind a cascade plot for one token.
#
# Arguments:
#   cascade   List with elements `casc` (holding `cascade_nodes` and
#             `cascade_times`, both named by token) and `n` (a data frame
#             with an `id` column to join on).
#   token     Name of the cascade to extract.
#   mergedat  Member metadata with columns `id`, `lastname`,
#             `nominate.dim1` and `party_code`; defaults to the global
#             `member_meta`.
#
# Returns a data frame with one row per node in the token's cascade: `id`,
# `times`, the metadata columns from `mergedat`, and the columns of
# `cascade$n`.
#
# NOTE(review): `cascade_times` is converted with as_datetime()'s default
# origin. If the cascade stores seconds relative to another origin, the dates
# will be shifted -- confirm against the function that built `cascade`.
prep.cascadeplot <- function(cascade, token, mergedat = member_meta){
  nodes <- cascade$casc$cascade_nodes
  times <- cascade$casc$cascade_times
  
  id <- as.character(unlist(nodes[which(names(nodes) == token)]))
  tim <- lubridate::as_datetime(unlist(times[which(names(times) == token)]))
  
  outd <- cascade$n
  
  # join on the metadata that was passed in, not on the global object
  df <- data.frame(id = id,
                   times = tim) %>%
    left_join(mergedat[,c("id","lastname", "nominate.dim1","party_code")], by = "id") %>%
    left_join(outd, by = "id")
  
  # return visibly so the result prints and pipes like any other value
  df
}

# Plot members' first use of one token against NOMINATE first dimension.
#
# Arguments:
#   cascade  List with elements `casc` (holding `cascade_nodes` and
#            `cascade_times`, both named by token) and `n` (a data frame with
#            an `id` column and, for geom = "point", an `n` column).
#   token    Name of the cascade to plot.
#   topic    Topic label used in the title and the size legend.
#   geom     "point" draws points sized by `n`; any other value draws
#            repelled text labels (needs the ggrepel package).
#
# Member metadata is read from the global `member_meta` (columns `id`,
# `lastname`, `nominate.dim1`, `party_code`).
#
# Returns a ggplot object.
#
# NOTE(review): the name has the form of an S3 `plot` method for class
# "cascade"; confirm it cannot be picked up by plot() dispatch on cascade
# objects from NetworkInference.
# NOTE(review): `cascade_times` is converted with as_datetime()'s default
# origin. If the cascade stores seconds relative to another origin, the axis
# dates will be shifted -- confirm against the function that built `cascade`.
plot.cascade <- function(cascade, token, topic = "COVID-19", geom = "text"){
  nodes <- cascade$casc$cascade_nodes
  times <- cascade$casc$cascade_times
  
  id <- as.character(unlist(nodes[which(names(nodes) == token)]))
  tim <- lubridate::as_datetime(unlist(times[which(names(times) == token)]))
  
  outd <- cascade$n
  
  df <- data.frame(id = id,
                   times = tim) %>%
    left_join(member_meta[,c("id","lastname", "nominate.dim1","party_code")], by = "id") %>%
    left_join(outd, by = "id")
  
  # only the geometry (and its size legend) differs between the two variants
  if(geom == "point"){
    p <- 
      df %>% ggplot(aes(x = times, y = nominate.dim1,
                        label = lastname, size = n))+
      geom_point()+
      scale_size_continuous(name = paste0("n tweets about ", topic))
  }else{
    # ggrepel is needed for this branch only; fail with a clear message
    # instead of a later "could not find function" error
    if(!requireNamespace("ggrepel", quietly = TRUE)){
      stop("Package 'ggrepel' is required for text labels; install it or use geom = \"point\".",
           call. = FALSE)
    }
    p <- 
      df %>% ggplot(aes(x = times, y = nominate.dim1,
                        label = lastname))+
      ggrepel::geom_text_repel()
  }
  
  # shared labels and theme
  p +
    labs(title = paste0("Members' first use of token '", token, "' in tweets about ", topic),
         x = "Date", y = "NOMINATE First Dimension")+
    theme_bw()+
    theme(plot.title = element_text(size = 24))
}
